{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 数据预处理"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(847, 50, 42)\n",
      "(847, 2)\n"
     ]
    }
   ],
   "source": [
    "# Load the lane-change (LC) and lane-keeping (LK) sequences.\n",
    "dataLC = np.load('../data/x_all_lc.npy')\n",
    "dataLK = np.load('../data/x_all_lk.npy')\n",
    "print(dataLC.shape)\n",
    "# Load the risk labels. NOTE: delete the CSV header row (sample id, risk flag:\n",
    "# 1 = high risk, 0 = low risk) beforehand; it is text and cannot be parsed with dtype=int.\n",
    "dataFlag = np.loadtxt('../data/risk_label.csv', delimiter=\",\", dtype=int)\n",
    "print(dataFlag.shape)\n",
    "# Labels are matched to lane-change samples by row index further down,\n",
    "# so fail early if the two files do not describe the same number of samples.\n",
    "assert dataFlag.shape[0] == dataLC.shape[0], 'risk_label.csv row count must match x_all_lc.npy sample count'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "outputs": [],
   "source": [
    "# Assign every surrounding vehicle to a fixed slot according to its position\n",
    "# relative to the ego vehicle; this is also where empty vehicle entries are handled.\n",
    "# Slot order: ego -> left-front -> front -> right-front -> left-rear -> rear -> right-rear\n",
    "def DataSorting(sequence):\n",
    "    \"\"\"Reorder the surrounding vehicles of every sample into position slots.\n",
    "\n",
    "    For each sample, columns 0:6 are copied as the ego vehicle. Each of the six\n",
    "    following 6-column groups is classified with CarPosition and written by\n",
    "    AddCar into the slot of that position. A group for which CarPosition\n",
    "    returns False is not copied, so its slot stays zero. If two groups map to\n",
    "    the same position, the later one overwrites the earlier one.\n",
    "\n",
    "    Args:\n",
    "        sequence: array of shape (numSamples, timeSteps, 42).\n",
    "\n",
    "    Returns:\n",
    "        A new float array of shape (numSamples, timeSteps, 42). An input with\n",
    "        zero samples yields an empty array of that shape instead of None.\n",
    "    \"\"\"\n",
    "    numSamples, timeSteps = sequence.shape[0], sequence.shape[1]\n",
    "    # Allocate the result once: growing it with np.append per sample re-copies\n",
    "    # everything accumulated so far on every iteration.\n",
    "    sortedSeq = np.zeros((numSamples, timeSteps, 42))\n",
    "    for i in range(numSamples):\n",
    "        data = sequence[i, :, :]\n",
    "        sortedData = np.zeros((timeSteps, 42))\n",
    "        targetCar = data[:, 0:6]\n",
    "        sortedData[:, 0:6] = targetCar\n",
    "        for j in range(1, 7):\n",
    "            carData = data[:, j*6:j*6+6]\n",
    "            flag = CarPosition(targetCar, carData)\n",
    "            sortedData = AddCar(sortedData, carData, flag)\n",
    "        sortedSeq[i] = sortedData\n",
    "    return sortedSeq"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "outputs": [],
   "source": [
    "# 根据x,y坐标确定车辆的位置关系\n",
    "def CarPosition(car1, car2):\n",
    "    laneWidth = 1.5\n",
    "    # 如果是空数据，则直接返回\n",
    "    if car2[-1, 0] is None:\n",
    "        return False\n",
    "    # 右前车\n",
    "    if car2[-1, 0] > car1[-1, 0] and car2[-1, 1] - car1[-1, 1] > laneWidth:\n",
    "        return 3\n",
    "    # 左前车\n",
    "    elif car2[-1, 0] > car1[-1, 0] and car2[-1, 1] - car1[-1, 1] < -laneWidth:\n",
    "        return 1\n",
    "    # 前车\n",
    "    elif car2[-1, 0] > car1[-1, 0] and abs(car2[-1, 1] - car1[-1, 1]) <= laneWidth:\n",
    "        return 2\n",
    "    # 右后车\n",
    "    elif car2[-1, 0] < car1[-1, 0] and car2[-1, 1] - car1[-1, 1] > laneWidth:\n",
    "        return 6\n",
    "    # 左后车\n",
    "    elif car2[-1, 0] < car1[-1, 0] and car2[-1, 1] - car1[-1, 1] < -laneWidth:\n",
    "        return 4\n",
    "     # 后车\n",
    "    elif car2[-1, 0] < car1[-1, 0] and abs(car2[-1, 1] - car1[-1, 1]) <= laneWidth:\n",
    "        return 5\n",
    "\n",
    "    return False"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "outputs": [],
   "source": [
    "# 根据车辆位置关系，重构数据\n",
    "def AddCar(sequence, carData, carType):\n",
    "    # 左前车\n",
    "    if carType == 1:\n",
    "        sequence[:, 6:12] = carData\n",
    "    # 前车\n",
    "    elif carType == 2:\n",
    "        sequence[:, 12:18] = carData\n",
    "    # 右后车\n",
    "    elif carType == 3:\n",
    "        sequence[:, 18:24] = carData\n",
    "    # 左后车\n",
    "    elif carType == 4:\n",
    "        sequence[:, 24:30] = carData\n",
    "    # 后车\n",
    "    elif carType == 5:\n",
    "        sequence[:, 30:36] = carData\n",
    "    # 右后车\n",
    "    elif carType == 6:\n",
    "        sequence[:, 36:42] = carData\n",
    "    return sequence"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 数据重构，根据位置关系调整数据顺序：本车 -> 左前车 -> 前车 -> 右前车 -> 左后车 -> 后车 -> 右后车"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "outputs": [
    {
     "data": {
      "text/plain": "(847, 50, 42)"
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 换道车辆数据重构\n",
    "dataSortedLC = DataSorting(dataLC)\n",
    "dataSortedLC.shape"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "outputs": [
    {
     "data": {
      "text/plain": "(847, 50, 42)"
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 车道保持数据重构\n",
    "dataSortedLK = DataSorting(dataLK)\n",
    "dataSortedLK.shape"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 筛选出换道高风险数据与低风险数据"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "低风险数据维度： (795, 50, 42)\n",
      "高风险数据维度： (52, 50, 42)\n"
     ]
    }
   ],
   "source": [
    "# Split the re-ordered lane-change samples by risk label (last label column: 1 = high risk).\n",
    "# A boolean mask selects each class in one step instead of growing arrays with np.append,\n",
    "# and a class without samples becomes an empty array rather than None.\n",
    "riskLabels = dataFlag[:dataSortedLC.shape[0], -1]\n",
    "assert riskLabels.shape[0] == dataSortedLC.shape[0], 'fewer risk labels than lane-change samples'\n",
    "isHighRisk = riskLabels == 1\n",
    "highRisk = dataSortedLC[isHighRisk]\n",
    "lowRisk = dataSortedLC[~isHighRisk]\n",
    "print(\"低风险数据维度：\", lowRisk.shape)\n",
    "print(\"高风险数据维度：\", highRisk.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "outputs": [],
   "source": [
    "## 提取x,y坐标以及x,y速度\n",
    "def GetPosSpeed(sequence):\n",
    "    outArray = None\n",
    "    for i in range(sequence.shape[0]):\n",
    "        itemData = None\n",
    "        for j in range(7):\n",
    "            data = sequence[i, :, j*6:j*6+4].reshape(1, sequence.shape[1], 1, 4)\n",
    "            if itemData is None: itemData = data\n",
    "            else: itemData = np.append(itemData, data, axis=2)\n",
    "        if outArray is None: outArray = itemData\n",
    "        else: outArray = np.append(outArray, itemData, axis=0)\n",
    "    return outArray"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 划分训练集与测试集，划分比例train:test = 7:3\n",
    "#### 各类数据的个数如下表所示\n",
    "|     | 换道低风险 | 换道高风险 | 车道保持 |\n",
    "|-----| ---- |-------| ---- |\n",
    "| 训练集 | 557 | 36    | 593 |\n",
    "| 测试集 | 238 | 16    | 254 |\n",
    "| 合计 | 795 | 52    | 847 |"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集，低风险、高风险、车道保持： (557, 50, 7, 4) (36, 50, 7, 4) (593, 50, 7, 4)\n"
     ]
    }
   ],
   "source": [
    "# Training set: the leading 557 low-risk, 36 high-risk and 593 lane-keeping samples.\n",
    "train_lowRisk = GetPosSpeed(lowRisk[:557])\n",
    "train_highRisk = GetPosSpeed(highRisk[:36])\n",
    "train_laneKeep = GetPosSpeed(dataSortedLK[:593])\n",
    "trainShapes = (train_lowRisk.shape, train_highRisk.shape, train_laneKeep.shape)\n",
    "print(\"训练集，低风险、高风险、车道保持：\", *trainShapes)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试集，低风险、高风险、车道保持： (238, 50, 7, 4) (16, 50, 7, 4) (254, 50, 7, 4)\n"
     ]
    }
   ],
   "source": [
    "# Test set: everything after the first 557 low-risk, 36 high-risk and 593 lane-keeping samples.\n",
    "test_lowRisk = GetPosSpeed(lowRisk[557:])\n",
    "test_highRisk = GetPosSpeed(highRisk[36:])\n",
    "test_laneKeep = GetPosSpeed(dataSortedLK[593:])\n",
    "testShapes = (test_lowRisk.shape, test_highRisk.shape, test_laneKeep.shape)\n",
    "print(\"测试集，低风险、高风险、车道保持：\", *testShapes)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 训练集测试集数据保存"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "outputs": [],
   "source": [
    "# Write every train/test split to its own .npy file under ../data.\n",
    "splitFiles = [\n",
    "    (\"../data/train_lowRisk.npy\", train_lowRisk),\n",
    "    (\"../data/train_highRisk.npy\", train_highRisk),\n",
    "    (\"../data/train_laneKeep.npy\", train_laneKeep),\n",
    "    (\"../data/test_lowRisk.npy\", test_lowRisk),\n",
    "    (\"../data/test_highRisk.npy\", test_highRisk),\n",
    "    (\"../data/test_laneKeep.npy\", test_laneKeep),\n",
    "]\n",
    "for filePath, splitArray in splitFiles:\n",
    "    np.save(filePath, splitArray)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor(False)"
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: report for every saved split whether it contains NaN.\n",
    "# NumPy covers this, so torch is not needed just for the check.\n",
    "savedSplits = {\n",
    "    'train_lowRisk': train_lowRisk,\n",
    "    'train_highRisk': train_highRisk,\n",
    "    'train_laneKeep': train_laneKeep,\n",
    "    'test_lowRisk': test_lowRisk,\n",
    "    'test_highRisk': test_highRisk,\n",
    "    'test_laneKeep': test_laneKeep,\n",
    "}\n",
    "{name: bool(np.isnan(split).any()) for name, split in savedSplits.items()}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}